import numpy as np
import json
import random as rd
from parser1 import args

class DataLoader(object):
    """Load user -> items interaction splits from JSON and draw training batches.

    ``path`` must contain ``train.json``, ``val.json`` and ``test.json``. Each
    file is read as a JSON object whose keys are user ids (converted with
    ``int``) and whose values are lists of item ids.
    NOTE(review): item ids are assumed to be non-negative integers, since
    negatives are drawn from ``[0, n_items)`` -- confirm against the data.

    Attributes set by the constructor:
        batch_size: number of users returned per sampled batch.
        n_users: largest training user id + 1.
        n_items: largest item id seen in any split + 1.
        n_train, n_val, n_test: interaction counts per split.
        max_inter: longest item list seen (see ``_register_eval_split`` for
            which evaluation users contribute).
        exist_users: ids of users with a non-empty training list.
        train_items: dict ``uid -> training item list``.
        test_set, val_set: dict ``uid -> held-out item list``.
        test_items, val_items: training item lists of the evaluation users
            that also appear in ``train_items``, in file order.
        neg_pools: empty dict; this class never fills it.
    """

    def __init__(self, path, batch_size):
        super(DataLoader, self).__init__()
        self.batch_size = batch_size

        train_file = path + '/train.json'
        val_file = path + '/val.json'
        test_file = path + '/test.json'

        # Counters; n_users / n_items hold the max id until the final "+ 1".
        self.n_users, self.n_items, self.n_val = 0, 0, 0
        self.n_train, self.n_test = 0, 0
        self.neg_pools = {}
        self.max_inter = 0
        self.exist_users = []

        train = self._load_json(train_file)
        test = self._load_json(test_file)
        val = self._load_json(val_file)

        self.train_items, self.test_set, self.val_set = {}, {}, {}
        self.test_items, self.val_items = [], []

        for uid, items in train.items():
            if not items:
                # Users without training interactions cannot be sampled.
                continue
            uid = int(uid)
            self.exist_users.append(uid)
            self.n_items = max(self.n_items, max(items))
            self.n_users = max(self.n_users, uid)
            self.max_inter = max(self.max_inter, len(items))
            self.n_train += len(items)
            self.train_items[uid] = items

        self.n_test = self._register_eval_split(
            test, self.test_set, self.test_items)
        self.n_val = self._register_eval_split(
            val, self.val_set, self.val_items)

        # Convert "largest id" into "number of ids" (ids start at 0).
        self.n_items += 1
        self.n_users += 1

    @staticmethod
    def _load_json(file_path):
        """Parse one JSON file, closing the handle even if parsing fails."""
        with open(file_path) as handle:
            return json.load(handle)

    def _register_eval_split(self, split, split_set, split_items):
        """Record one evaluation split and return its interaction count.

        Users with an empty item list are skipped entirely. A user that has
        no entry in ``train_items`` is still stored in ``split_set``, counted
        and used for ``n_items``, but adds nothing to ``split_items`` and does
        not affect ``max_inter``. This is the state the previous
        implementation produced by swallowing the ``KeyError``.
        """
        n_interactions = 0
        for uid, items in split.items():
            uid = int(uid)
            if not items:
                continue
            self.n_items = max(self.n_items, max(items))
            n_interactions += len(items)
            split_set[uid] = items
            if uid not in self.train_items:
                continue
            split_items.append(self.train_items[uid])
            self.max_inter = max(self.max_inter, len(items))
        return n_interactions

    def _sample_users(self):
        """Draw ``batch_size`` user ids from ``exist_users``.

        Sampling is without replacement when enough users exist and with
        replacement otherwise. The population size is ``len(exist_users)``
        rather than ``n_users``: ``n_users`` is the largest id + 1 and can
        exceed the number of sampleable users, which made ``random.sample``
        raise ``ValueError``.
        """
        population = self.exist_users
        if self.batch_size <= len(population):
            return rd.sample(population, self.batch_size)
        return [rd.choice(population) for _ in range(self.batch_size)]

    def _sample_pos_items_for_u(self, u, num):
        """Return ``num`` distinct training items of user ``u``, drawn uniformly.

        Raises:
            ValueError: if ``u`` has fewer than ``num`` distinct items (the
                rejection loop below would otherwise never terminate).
        """
        pos_items = self.train_items[u]
        n_pos_items = len(pos_items)
        # The distinct count is only needed when more than one item is
        # requested; skip the set build on the common num == 1 path.
        if num > 1 and num > len(set(pos_items)):
            raise ValueError(
                'user %r has fewer than %d distinct training items' % (u, num))
        pos_batch = []
        while len(pos_batch) < num:
            pos_id = np.random.randint(low=0, high=n_pos_items, size=1)[0]
            pos_i_id = pos_items[pos_id]
            if pos_i_id not in pos_batch:
                pos_batch.append(pos_i_id)
        return pos_batch

    def _sample_neg_items_for_u(self, u, num):
        """Return ``num`` distinct item ids in ``[0, n_items)`` not in ``u``'s training list.

        Raises:
            ValueError: if fewer than ``num`` such items exist (the rejection
                loop below would otherwise never terminate).
        """
        interacted = set(self.train_items[u])  # built once per call
        n_blocked = sum(1 for item in interacted if 0 <= item < self.n_items)
        if num > self.n_items - n_blocked:
            raise ValueError(
                'user %r has fewer than %d non-interacted items' % (u, num))
        neg_items = []
        while len(neg_items) < num:
            neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
            if neg_id not in interacted and neg_id not in neg_items:
                neg_items.append(neg_id)
        return neg_items

    def _sample_neg_items_for_u_from_pools(self, u, num):
        """Draw ``num`` negatives for ``u`` from ``neg_pools[u]`` minus its training items.

        Not called by this class; ``neg_pools`` must be filled by the caller
        before use.
        """
        candidates = list(set(self.neg_pools[u]) - set(self.train_items[u]))
        return rd.sample(candidates, num)

    def sample(self):
        """Draw a batch of users, one positive item each, and their histories.

        Returns:
            tuple ``(users, pos_items, train_items, max_inter)`` where
            ``pos_items[i]`` is one training item of ``users[i]``,
            ``train_items[i]`` is the full training list of ``users[i]``, and
            ``max_inter`` is ``self.max_inter``.
        """
        users = self._sample_users()
        pos_items, train_items = [], []
        for u in users:
            pos_items += self._sample_pos_items_for_u(u, 1)
            train_items.append(self.train_items[u])
        return users, pos_items, train_items, self.max_inter

    def GCN_sample(self):
        """Draw a batch of ``(user, positive item, negative item)`` triples.

        Returns:
            tuple ``(users, pos_items, neg_items)`` of equal-length lists.
            ``neg_items[i]`` is an item id that does not occur in the
            training list of ``users[i]``.
        """
        users = self._sample_users()
        pos_items, neg_items = [], []
        for u in users:
            # Positive then negative per user, so the numpy RNG is consumed
            # in the same order as before.
            pos_items += self._sample_pos_items_for_u(u, 1)
            neg_items += self._sample_neg_items_for_u(u, 1)
        return users, pos_items, neg_items